# 导入VanillaVAE
from models import VanillaVAE
import torch
from torch.utils.tensorboard import SummaryWriter

# 创建VanillaVAE对象
vae = VanillaVAE(3, 128)

print(vae)

# 为了测试，我们随机生成一些数据
x = torch.randn(16, 3, 28, 28)

# 将数据输入到模型中
decode, input, mu, log_var = vae(x)

with SummaryWriter(log_dir='') as sw:  # 实例化 SummaryWriter ,可以自定义数据输出路径
    sw.add_graph(vae, x)  # 输出网络结构图
    sw.close()  # 关闭  sw
